package cn.edu.bjtu.model.word2vec;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.bjtu.word2vec.ApplicationConfig;

import cn.edu.bjtu.abstractimpl.analyzer.AnsjDocumentAnalyzer;
import cn.edu.bjtu.abstractimpl.wordfilters.RegExNumbefFilter;
import cn.edu.bjtu.abstractimpl.wordfilters.SingleWordTermFilter;
import cn.edu.bjtu.classimpl.documententity.Document;
import cn.edu.bjtu.classimpl.documentvec.DocumentVectorImpl;
import cn.edu.bjtu.classimpl.parser.HiveWechatParser;
import cn.edu.bjtu.classimpl.parser.LineParser;
import cn.edu.bjtu.general.math.DenseVector;
import cn.edu.bjtu.general.math.Vector;
import cn.edu.bjtu.interfaces.ILearnerModel;
import cn.edu.bjtu.interfaces.ModelStatus;
import cn.edu.bjtu.interfaces.document.IDocument;
import cn.edu.bjtu.interfaces.parser.Parser;
import cn.edu.bjtu.interfaces.segment.DocumentSegmentation;
import cn.edu.bjtu.interfaces.vector.IDocumentVector;
import cn.edu.bjtu.interfaces.wordfilters.IWord;
import cn.edu.bjtu.interfaces.wordfilters.WordsFilter;
import cn.edu.bjtu.model.word2vec.util.WordInWord2Vec;
import cn.edu.bjtu.tools.FileUtil;

/**
 * Transforms documents into dense vectors by summing (optionally
 * count-weighted) word2vec embeddings of their segmented tokens.
 *
 * <p>The binary model file is loaded once through the lazily-initialized
 * singleton obtained via {@link #getInstance(DocumentSegmentation)}.
 * Words absent from the model contribute a zero vector.
 *
 * <p>NOTE(review): not thread-safe — the singleton initialization and the
 * mutable {@code docSeg} field are unsynchronized, matching the original
 * behavior.
 */
public class Word2VecForTransformingDocs implements ILearnerModel {
	private static final Log LOG = LogFactory
			.getLog(Word2VecForTransformingDocs.class);

	// word -> unit-normalized embedding loaded from the binary model file
	private Map<String, float[]> wordMap = new HashMap<String, float[]>();
	// word count declared in the model file header
	private int words;
	// embedding dimensionality declared in the model file header
	private int size;
	private int topNSize = 40;
	private static final int MAX_SIZE = 50;
	// dimension used for document vectors; assumed equal to the model's
	// declared size — TODO confirm against the shipped model file
	private static int wordVecDim = 200;
	private static String RESOURCE_MODEL = "";
	private static Word2VecForTransformingDocs word2Vec = null;
	private DocumentSegmentation docSeg = null;
	private Set<WordsFilter> filtersSet = new HashSet<WordsFilter>();
	CountDownLatch latch;
	// monotonic id source for documents produced by transformDoc(String)
	AtomicLong cur = new AtomicLong(0L);

	static {
		RESOURCE_MODEL = ApplicationConfig.getInstance()
				.getProperty("word2vec_model_dir") + "word2vecModel";
	}

	/**
	 * Returns the shared instance, loading the model on first use and
	 * (re)binding the supplied segmenter.
	 *
	 * @param docSeg segmenter used for all subsequent transforms
	 * @return the singleton transformer
	 */
	public static Word2VecForTransformingDocs getInstance(
			DocumentSegmentation docSeg) {
		if (word2Vec == null) {
			word2Vec = new Word2VecForTransformingDocs();
			word2Vec.loadModelFromFile(new File(RESOURCE_MODEL));
			// Register filters once. Adding them on every call (as before)
			// grew the set with duplicate instances, because the filter
			// classes do not define value equality.
			word2Vec.filtersSet.add(new SingleWordTermFilter());
			word2Vec.filtersSet.add(new RegExNumbefFilter());
		}
		word2Vec.docSeg = docSeg;
		return word2Vec;
	}

	/**
	 * Transforms the given documents; result is discarded (kept for
	 * {@code ILearnerModel} compatibility).
	 */
	public void fit(Iterator<IDocument> docs) {
		// Use this instance rather than the static singleton so a directly
		// constructed object does not NPE.
		this.transform(docs);
	}

	/**
	 * Converts one document into a vector: segment, filter, count term
	 * frequencies, then sum the count-weighted word embeddings.
	 *
	 * @param doc document to vectorize
	 * @return vector carrying the document's id and tag
	 */
	public IDocumentVector transform(IDocument doc) {
		// term -> raw occurrence count in the filtered segmentation
		Map<String, Double> countWords = new HashMap<String, Double>();
		Vector docVector = new DenseVector(wordVecDim);
		String segResult = filterTerms(this.docSeg.segment(doc));
		for (String token : segResult.split("\\s+")) {
			Double prev = countWords.get(token);
			countWords.put(token, prev == null ? 1.0D : prev + 1.0D);
		}
		// TF-IDF weighting was intentionally disabled in the original code;
		// raw counts serve as word weights.
		for (Map.Entry<String, Double> entry : countWords.entrySet()) {
			IWord word = new WordInWord2Vec(entry.getKey(),
					entry.getValue().doubleValue());
			// single lookup; queryWord returns a zero vector for OOV words
			Vector wordVec = queryWord(word);
			weightedVec(wordVec, word.getWeight());
			wordToCount(docVector, wordVec);
		}
		return new DocumentVectorImpl(doc.getID(), doc.getTag(), docVector);
	}

	/**
	 * Scales {@code wordVec} in place by {@code weight}.
	 *
	 * @param wordVec vector to scale (mutated)
	 * @param weight  scalar multiplier
	 */
	public void weightedVec(Vector wordVec, double weight) {
		double[] scaled = new double[wordVec.size()];
		for (int i = 0; i < wordVec.size(); i++) {
			scaled[i] = wordVec.get(i) * weight;
		}
		wordVec.assign(scaled);
	}

	/**
	 * Accumulates {@code wordVec} into {@code docVector} in place
	 * (element-wise addition).
	 *
	 * @param docVector running document sum (mutated)
	 * @param wordVec   vector to add
	 */
	public void wordToCount(Vector docVector, Vector wordVec) {
		double[] sum = new double[docVector.size()];
		for (int i = 0; i < docVector.size(); i++) {
			sum[i] = docVector.get(i) + wordVec.get(i);
		}
		docVector.assign(sum);
	}

	/**
	 * Transforms every document from the iterator.
	 *
	 * @param docs documents to vectorize
	 * @return one vector per input document, in iteration order
	 */
	public List<IDocumentVector> transform(Iterator<IDocument> docs) {
		List<IDocumentVector> docVecs = new ArrayList<IDocumentVector>();
		while (docs.hasNext()) {
			docVecs.add(transform(docs.next()));
		}
		LOG.info("docVecs" + docVecs.size());
		return docVecs;
	}

	/**
	 * Looks up the embedding for a word.
	 *
	 * @param word word to look up
	 * @return the stored embedding as a dense vector, or an all-zero vector
	 *         of {@code wordVecDim} when the word is not in the model
	 */
	public Vector queryWord(IWord word) {
		// single map lookup instead of the former look-then-fetch pair
		float[] stored = getWordVector(word.getWord());
		float[] source = (stored != null) ? stored : new float[wordVecDim];
		double[] trans = new double[source.length];
		for (int i = 0; i < source.length; i++) {
			trans[i] = source[i];
		}
		return new DenseVector(trans);
	}

	/** Not implemented; always returns {@code null}. */
	public Iterator<IWord> getTopKSimilar(int k, IWord word) {
		return null;
	}

	/**
	 * Runs every registered {@link WordsFilter} over the segmented text.
	 *
	 * @param segResult whitespace-separated segmentation output
	 * @return filtered segmentation string
	 */
	protected String filterTerms(String segResult) {
		String result = segResult;
		for (WordsFilter filter : this.filtersSet) {
			result = filter.filter(result);
		}
		return result;
	}

	/**
	 * Reads one UTF-8 file line by line into documents, skipping lines whose
	 * parsed content is blank. (Method name misspelling kept for caller
	 * compatibility.)
	 *
	 * @param filePath file to read
	 * @return parsed non-blank documents
	 * @throws IOException if the file cannot be read
	 */
	public Iterator<IDocument> getDoucument(String filePath)
			throws IOException {
		List<IDocument> documents = new ArrayList<IDocument>();
		// try-with-resources: the original leaked the reader when a parse
		// or read threw before close()
		try (BufferedReader ber = new BufferedReader(new InputStreamReader(
				new FileInputStream(filePath), "UTF-8"))) {
			String line;
			while ((line = ber.readLine()) != null) {
				// lines with more than 8 tab-separated fields are treated as
				// Hive/WeChat export rows — TODO confirm the column layout
				Parser p = (line.split("\t").length > 8)
						? new HiveWechatParser() : new LineParser();
				IDocument document = p.parse(line);
				// drop documents whose content is only spaces
				if (document.getContent().replaceAll(" ", "").length() > 0) {
					documents.add(document);
				}
			}
		}
		LOG.info("文档数量=" + documents.size());
		return documents.iterator();
	}

	/**
	 * Recursively collects documents from every file under a directory;
	 * unreadable files are skipped with an error log.
	 *
	 * @param filePath directory path (non-directories yield an empty result)
	 * @return all parsed documents
	 */
	public Iterator<IDocument> getDocuments(String filePath) {
		List<IDocument> documents = new ArrayList<IDocument>();
		File file = new File(filePath);
		if (file.isDirectory()) {
			List<File> listFile = new ArrayList<File>();
			FileUtil.getFilepath(file, listFile);
			for (File f : listFile) {
				try {
					Iterator<IDocument> temp = getDoucument(
							f.getAbsolutePath());
					while (temp.hasNext()) {
						documents.add(temp.next());
					}
				} catch (IOException e) {
					// best-effort per file, but keep the full cause in the
					// log instead of just e.getMessage()
					LOG.error("Failed to read " + f.getAbsolutePath(), e);
				}
			}
		}
		return documents.iterator();
	}

	/**
	 * Vectorizes raw text directly: segment, filter, then sum the
	 * (unweighted) word embeddings. The resulting vector gets a generated
	 * id and the tag {@code "unknown"}.
	 *
	 * @param doc raw text
	 * @return a single-element list with the document vector
	 */
	public List<IDocumentVector> transformDoc(String doc) {
		List<IDocumentVector> listDocumentVec = new ArrayList<IDocumentVector>();
		Vector docVector = new DenseVector(wordVecDim);
		// segment, then filter out single-character and numeric tokens
		String segResult = filterTerms(this.docSeg.segment(doc));
		for (String token : segResult.split("\\s+")) {
			IWord word = new WordInWord2Vec(token);
			// zero vector for words missing from the model
			Vector wordVec = queryWord(word);
			LOG.info("word=" + word.getWord() + "\t" + wordVec.get(0));
			wordToCount(docVector, wordVec);
		}
		listDocumentVec.add(new DocumentVectorImpl(
				String.valueOf(this.cur.getAndIncrement()),
				"unknown", docVector));
		return listDocumentVec;
	}

	/**
	 * Loads the binary word2vec model: an int word count, an int dimension,
	 * then per word a UTF string key followed by {@code size} floats. Each
	 * vector is unit-normalized before storage.
	 *
	 * @param f model file
	 * @return always {@code null} (status reporting was never implemented)
	 */
	public ModelStatus loadModelFromFile(File f) {
		if (!isModelExist()) {
			LOG.error("The model doesn't exist");
			return null;
		}
		try (DataInputStream dis = new DataInputStream(
				new BufferedInputStream(new FileInputStream(
						f.getAbsolutePath())))) {
			words = dis.readInt();
			size = dis.readInt();
			for (int i = 0; i < words; i++) {
				String key = dis.readUTF();
				float[] value = new float[size];
				double len = 0;
				for (int j = 0; j < size; j++) {
					float component = dis.readFloat();
					len += component * component;
					value[j] = component;
				}
				len = Math.sqrt(len);
				// guard the zero vector: dividing by len == 0 would fill
				// the entry with NaN components
				if (len > 0) {
					for (int j = 0; j < size; j++) {
						value[j] /= len;
					}
				}
				wordMap.put(key, value);
			}
		} catch (Exception e) {
			// keep the cause — a truncated/corrupt model is otherwise
			// impossible to diagnose from the log
			LOG.error("Loading model failed, please check the model file!", e);
		}
		return null;
	}

	/**
	 * @return whether the configured model file exists on disk
	 */
	public boolean isModelExist() {
		return new File(RESOURCE_MODEL).exists();
	}

	/**
	 * Returns the stored (normalized) embedding for a word.
	 *
	 * @param word word key
	 * @return the embedding array, or {@code null} if the word is unknown
	 */
	public float[] getWordVector(String word) {
		return wordMap.get(word);
	}

	public static void main(String[] args) {
		DocumentSegmentation docSeg = new AnsjDocumentAnalyzer();
		Word2VecForTransformingDocs wordVecDocs = getInstance(docSeg);
		IDocument doc = new Document("112131234", "324324",
				"推动宣传教育宣传");
		IDocumentVector docVec = wordVecDocs.transform(doc);
		System.out.println();
	}

}